##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

"""
Implementation of Driver class
It is a wrapper around the actual psycopg2 driver, and connection
object.

"""
import datetime
from flask import session
from flask_babelex import gettext
import psycopg2
from psycopg2.extensions import adapt

import config
from pgadmin.model import Server, User
from .keywords import ScanKeyword
from ..abstract import BaseDriver
from .connection import Connection
from .server_manager import ServerManager


class Driver(BaseDriver):
    """
    class Driver(BaseDriver):

    This driver acts as a wrapper around psycopg2 connection driver
    implementation. We will be using psycopg2 for making connections with
    the PostgreSQL/EDB Postgres Advanced Server (EnterpriseDB).

    Properties:
    ----------

    * Version (string):
        Version of psycopg2 driver

    Methods:
    -------
    * get_connection(sid, database, conn_id, auto_reconnect)
    - It returns a Connection class object, which may/may not be connected
      to the database server for this session

    * release_connection(sid, database, conn_id)
    - It releases the connection object for the given conn_id/database for this
      session.

    * connection_manager(sid)
    - It returns the server connection manager for this session.
    """

    def __init__(self, **kwargs):
        """
        Set up an empty per-session manager registry and initialise the
        base driver.

        Any keyword arguments are accepted but not used here.
        """
        # Keyed by session id; each value is a dict holding that session's
        # ServerManager objects (keyed by str(server id)) and a 'pinged'
        # timestamp.
        self.managers = {}

        super(Driver, self).__init__()

    def connection_manager(self, sid=None):
        """
        connection_manager(...)

        Return the ServerManager for server `sid` in the current session,
        creating it (and the session's manager registry) when needed.

        Parameters:
            sid
            - Server ID, must be an integer.

        Returns:
            The ServerManager object, or None when no Server row with the
            given id exists.
        """
        assert (sid is not None and isinstance(sid, int))

        session_id = session.sid
        server_key = str(sid)

        if session_id in self.managers:
            managers = self.managers[session_id]
        else:
            managers = self.managers[session_id] = {}

            # First time this session is seen by this driver object:
            # rebuild the managers from the state saved in the session.
            if '__pgsql_server_managers' in session:
                saved_state = session['__pgsql_server_managers'].copy()
                session['__pgsql_server_managers'] = {}

                for server_id in saved_state:
                    server = Server.query.filter_by(id=server_id).first()

                    # The server no longer exists, nothing to restore.
                    if not server:
                        continue

                    restored = ServerManager(server)
                    managers[str(server_id)] = restored
                    restored._restore(saved_state[server_id])
                    restored.update_session()

        # Remember when this session was last active (used by gc()).
        managers['pinged'] = datetime.datetime.now()

        if server_key in managers:
            return managers[server_key]

        server = Server.query.filter_by(id=sid).first()
        if not server:
            return None

        managers[server_key] = ServerManager(server)
        return managers[server_key]

    def Version(cls):
        """
        Version(...)

        Returns the current version of psycopg2 driver, as reported by
        psycopg2.__version__.

        Note: there is no classmethod decorator on this method, so `cls`
        receives the instance it is called on.

        Raises:
            Exception - when psycopg2 does not expose a (non-empty)
            __version__ attribute.
        """
        driver_version = getattr(psycopg2, '__version__', None)

        if not driver_version:
            raise Exception(
                "Driver Version information for psycopg2 is not available!"
            )

        return driver_version

    def libpq_version(cls):
        """
        Returns the loaded libpq version, as reported by
        psycopg2.__libpq_version__.

        Note: there is no classmethod decorator on this method, so `cls`
        receives the instance it is called on.

        Raises:
            Exception - when psycopg2 does not expose a (non-zero)
            __libpq_version__ attribute.
        """
        loaded_version = getattr(psycopg2, '__libpq_version__', None)

        if not loaded_version:
            raise Exception(
                "libpq version information is not available!"
            )

        return loaded_version

    def get_connection(
            self, sid, database=None, conn_id=None, auto_reconnect=True
    ):
        """
        get_connection(...)

        Returns the connection object for the given connection-id/database
        of the server identified by sid. The work is delegated to the
        server manager returned by connection_manager(sid), which creates a
        new Connection object (if required).

        Parameters:
            sid
            - Server ID
            database
            - Database, on which the connection needs to be made.
              If None is provided, maintenance_db for the server will be
              used, while generating new Connection object.
            conn_id
            - Identification string for the connection. This will be used
              by certain tools, which require a dedicated connection,
              i.e. Debugger, Query Tool, etc.
            auto_reconnect
            - Defines whether we should attempt to reconnect to the
              database server automatically, when the connection has been
              lost for any reason. Certain tools like Debugger require a
              permanent connection, and stop working on disconnection.

        Note: connection_manager() returns None when no server with the
        given sid exists; that case is not handled here.
        """
        return self.connection_manager(sid).connection(
            database, conn_id, auto_reconnect
        )

    def release_connection(self, sid, database=None, conn_id=None):
        """
        Release the connection identified by conn_id/database on server
        `sid` for the current session.

        Returns whatever the server manager's release() call returns.
        """
        server_manager = self.connection_manager(sid)

        return server_manager.release(database, conn_id)

    def delete_manager(self, sid):
        """
        Release the manager of the given server id and drop it from the
        current session's registry.
        """
        server_manager = self.connection_manager(sid)
        if server_manager is not None:
            server_manager.release()

        server_key = str(sid)
        session_managers = self.managers.get(session.sid)

        if session_managers is not None and server_key in session_managers:
            del session_managers[server_key]

    def gc(self):
        """
        Release the connections for the sessions, which have not pinged the
        server for more than config.MAX_SESSION_IDLE_TIME.

        The current session is never collected; its 'pinged' timestamp is
        refreshed instead.
        """

        # Minimum session idle is 20 minutes (60 is used when the setting is
        # not configured).
        max_idle_time = max(config.MAX_SESSION_IDLE_TIME or 60, 20)
        session_idle_timeout = datetime.timedelta(minutes=max_idle_time)

        curr_time = datetime.datetime.now()

        # Iterate over a snapshot, so that the registry is not walked while
        # managers are being released.
        for sess, sess_mgr in list(self.managers.items()):
            if sess == session.sid:
                sess_mgr['pinged'] = curr_time
                continue

            if curr_time - sess_mgr['pinged'] < session_idle_timeout:
                continue

            # The per-session dict mixes ServerManager objects with the
            # 'pinged' timestamp. Look at the *values*: iterating the dict
            # itself yields only its string keys, which never match.
            idle_managers = [
                m for m in sess_mgr.values() if isinstance(m, ServerManager)
            ]
            for mgr in idle_managers:
                mgr.release()

    @staticmethod
    def qtLiteral(value):
        """
        Return `value` quoted as an SQL literal, using psycopg2's adapt().

        The quoted form is decoded as UTF-8 when the adapter returns bytes;
        otherwise it is returned unchanged.
        """
        wrapper = adapt(value)

        # Only some adapted objects carry an encoding, e.g. these do not:
        #   psycopg2.extensions.BOOLEAN
        #   psycopg2.extensions.FLOAT
        #   psycopg2.extensions.INTEGER
        #   etc...
        if hasattr(wrapper, 'encoding'):
            wrapper.encoding = 'utf8'

        quoted = wrapper.getquoted()

        return quoted.decode('utf-8') if isinstance(quoted, bytes) else quoted

    @staticmethod
    def ScanKeywordExtraLookup(key):
        # UNRESERVED_KEYWORD      0
        # COL_NAME_KEYWORD        1
        # TYPE_FUNC_NAME_KEYWORD  2
        # RESERVED_KEYWORD        3
        extraKeywords = {
            'connect': 3,
            'convert': 3,
            'distributed': 0,
            'exec': 3,
            'log': 0,
            'long': 3,
            'minus': 3,
            'nocache': 3,
            'number': 3,
            'package': 3,
            'pls_integer': 3,
            'raw': 3,
            'return': 3,
            'smalldatetime': 3,
            'smallfloat': 3,
            'smallmoney': 3,
            'sysdate': 3,
            'systimestap': 3,
            'tinyint': 3,
            'tinytext': 3,
            'varchar2': 3
        }

        return extraKeywords.get(key, None) or ScanKeyword(key)

    @staticmethod
    def needsQuoting(key, forTypes):
        value = key
        valNoArray = value

        # check if the string is number or not
        if isinstance(value, int):
            return True
        # certain types should not be quoted even though it contains a space.
        # Evilness.
        elif forTypes and value[-2:] == u"[]":
            valNoArray = value[:-2]

        if forTypes and valNoArray.lower() in [
            u'bit varying',
            u'"char"',
            u'character varying',
            u'double precision',
            u'timestamp without time zone',
            u'timestamp with time zone',
            u'time without time zone',
            u'time with time zone',
            u'"trigger"',
            u'"unknown"'
        ]:
            return False

        # If already quoted?, If yes then do not quote again
        if forTypes and valNoArray:
            if valNoArray.startswith('"') \
                    or valNoArray.endswith('"'):
                return False

        if u'0' <= valNoArray[0] <= u'9':
            return True

        for c in valNoArray:
            if (not (u'a' <= c <= u'z') and c != u'_' and
                    not (u'0' <= c <= u'9')):
                return True

        # check string is keywaord or not
        category = Driver.ScanKeywordExtraLookup(value)

        if category is None:
            return False

        # UNRESERVED_KEYWORD
        if category == 0:
            return False

        # COL_NAME_KEYWORD
        if forTypes and category == 1:
            return False

        return True

    @staticmethod
    def qtTypeIdent(conn, *args):
        # We're not using the conn object at the moment, but - we will
        # modify the
        # logic to use the server version specific keywords later.
        res = None
        value = None

        for val in args:
            if len(val) == 0:
                continue
            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')
            value = val

            if (Driver.needsQuoting(val, True)):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res

    @staticmethod
    def qtIdent(conn, *args):
        # We're not using the conn object at the moment, but - we will
        # modify the logic to use the server version specific keywords later.
        res = None
        value = None

        for val in args:
            if type(val) == list:
                return map(lambda w: Driver.qtIdent(conn, w), val)
            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')

            if len(val) == 0:
                continue

            value = val

            if (Driver.needsQuoting(val, False)):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res
